/*
 *  Copyright 1999-2019 Seata.io Group.
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */
package com.auditlog.datasource;

import com.auditlog.datasource.context.ConnectionContext;
import com.auditlog.datasource.db.SqlType;
import com.auditlog.datasource.struct.SqlMeta;
import com.auditlog.sql.parser.CachedSqlParser;
import com.auditlog.datasource.context.ContextHolder;
import com.auditlog.datasource.table.TableMeta;
import com.auditlog.datasource.table.cache.TableMetaCacheFactory;
import com.auditlog.proxy.InvocationHandlerWrapper;
import com.auditlog.proxy.InvocationHanlderInterceptor;
import com.auditlog.proxy.ProxyFactory;
import com.auditlog.util.JdbcConstants;
import lombok.Getter;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.statement.insert.Insert;

import java.lang.reflect.Method;
import java.sql.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.Executor;

/**
 * The type Abstract connection proxy.
 *
 * @author sharajava
 */
@Slf4j
public abstract class AbstractConnectionProxy implements Connection {

    /**
     * 关联的数据源代理
     */
    protected DataSourceProxy dataSourceProxy;

    /**
     * 目标连接
     */
    protected Connection targetConnection;

    /**
     * 存放一些sql执行的信息
     */
    @Getter
    protected ConnectionContext connectionContext = ContextHolder.getConnectionContext();

    private String addBatchErrorMsg = "DML Returning cannot be batched";

    public AbstractConnectionProxy(DataSourceProxy dataSourceProxy, Connection targetConnection) {
        this.dataSourceProxy = dataSourceProxy;
        this.targetConnection = targetConnection;
    }

    public DataSourceProxy getDataSourceProxy() {
        return dataSourceProxy;
    }

    public Connection getTargetConnection() {
        return targetConnection;
    }

    public String getDbType() {
        return dataSourceProxy.getDbType();
    }

    @Override
    public Statement createStatement() throws SQLException {
        Statement targetStatement = getTargetConnection().createStatement();
        return new StatementProxy(this, targetStatement);
    }

    @SneakyThrows
    @Override
    public PreparedStatement prepareStatement(String sql) throws SQLException {
        String dbType = getDbType();
        // support oracle 10.2+
        PreparedStatement targetPreparedStatement = null;

        SqlMeta sqlMeta = CachedSqlParser.parse(sql);
        SqlType sqlType = sqlMeta.getSqlType();

        if (sqlType == SqlType.EXCEPTION) {
            throw new SQLException("sql解析出现异常");
        }

        if (sqlType == SqlType.MULTI) {
            this.getConnectionContext().setPrepareMultiSql(true);
            List<SqlType> sqlTypes = sqlMeta.getSqlTypes();
            if (sqlTypes.contains(SqlType.INSERT_DUPLICATE)) {
                // 因为当没指定主键时，根据getGenerateKeys获取到的不准确
                throw new IllegalStateException("不支持多个sql同时执行中包含insert on duplicate key操作:" + sql);
            }
        }
        if (sqlType == SqlType.INSERT || sqlType == SqlType.INSERT_DUPLICATE) {
            Insert insert = (Insert) sqlMeta.getStatements().get(0);
            TableMeta tableMeta = TableMetaCacheFactory.getTableMetaCache(dbType).getTableMeta(getTargetConnection(),
                    insert.getTable().getName(), getDataSourceProxy().getResourceId());
            String[] pkNameArray = new String[tableMeta.getPrimaryKeyOnlyName().size()];
            tableMeta.getPrimaryKeyOnlyName().toArray(pkNameArray);
            if (insert.getSelect() == null || insert.getSelect() != null && !this.getDbType().equals(JdbcConstants.ORACLE_STR)) {
                // oracle在insert select中，如果指定pkNameArray执行时会抛异常【sql未正常结束】
                if (this.getConnectionContext().isAutoAssignGeneratedKeys()) {
                    targetPreparedStatement = getTargetConnection().prepareStatement(sql, pkNameArray);
                }
            }
        }
        if (targetPreparedStatement == null) {
            targetPreparedStatement = getTargetConnection().prepareStatement(sql);
        }
        PreparedStatementProxy preparedStatementProxy = new PreparedStatementProxy(this, targetPreparedStatement, sql);
        if (this.getDbType().equals(JdbcConstants.ORACLE_STR)) {
            return (PreparedStatement) ProxyFactory.getProxy(preparedStatementProxy, new InvocationHanlderInterceptor() {
                @Override
                public void onException(InvocationHandlerWrapper handlerWrapper, Method method, Object[] args, Throwable throwable) throws Throwable {
                    log.debug("error happened:{}, recreate preparestatement", throwable.getCause().getMessage());
                    if (method.getName().equals("addBatch") && throwable.getCause() instanceof SQLException &&
                            throwable.getCause().getMessage().contains(addBatchErrorMsg)) {
                        PreparedStatementProxy statementProxy = (PreparedStatementProxy) handlerWrapper.getTarget();
                        ConnectionProxy connectionProxy = statementProxy.getConnectionProxy();
                        connectionProxy.getConnectionContext().setAutoAssignGeneratedKeys(false);
                        PreparedStatement preparedStatement = connectionProxy.prepareStatement(statementProxy.getTargetSQL());
                        // 如果是addBatch报错，这里最多设置了一次值之后就会报错
                        Map<Integer, ArrayList<Object>> parameters = statementProxy.getParameters();
                        for (Integer index : parameters.keySet()) {
                            preparedStatement.setObject(index, parameters.get(index).get(0));
                        }
                        preparedStatement.addBatch();
                        handlerWrapper.setTarget(preparedStatement);
                    } else {
                        throw throwable;
                    }
                }
            });
        }
        return preparedStatementProxy;
    }

    @Override
    public CallableStatement prepareCall(String sql) throws SQLException {
        // TOOD 判断是否有标记@auditlog注解，如果有就报错
        return targetConnection.prepareCall(sql);
    }

    @Override
    public String nativeSQL(String sql) throws SQLException {
        return targetConnection.nativeSQL(sql);
    }

    @Override
    public boolean getAutoCommit() throws SQLException {
        return targetConnection.getAutoCommit();
    }

    @Override
    public void close() throws SQLException {
        try {
            targetConnection.close();
        } finally {
            ContextHolder.clear();
        }
    }

    @Override
    public boolean isClosed() throws SQLException {
        return targetConnection.isClosed();
    }

    @Override
    public DatabaseMetaData getMetaData() throws SQLException {
        return targetConnection.getMetaData();
    }

    @Override
    public void setReadOnly(boolean readOnly) throws SQLException {
        targetConnection.setReadOnly(readOnly);

    }

    @Override
    public boolean isReadOnly() throws SQLException {
        return targetConnection.isReadOnly();
    }

    @Override
    public void setCatalog(String catalog) throws SQLException {
        targetConnection.setCatalog(catalog);

    }

    @Override
    public String getCatalog() throws SQLException {
        return targetConnection.getCatalog();
    }

    @Override
    public void setTransactionIsolation(int level) throws SQLException {
        targetConnection.setTransactionIsolation(level);

    }

    @Override
    public int getTransactionIsolation() throws SQLException {
        return targetConnection.getTransactionIsolation();
    }

    @Override
    public SQLWarning getWarnings() throws SQLException {
        return targetConnection.getWarnings();
    }

    @Override
    public void clearWarnings() throws SQLException {
        targetConnection.clearWarnings();

    }

    @Override
    public Statement createStatement(int resultSetType, int resultSetConcurrency) throws SQLException {
        Statement statement = targetConnection.createStatement(resultSetType, resultSetConcurrency);
        return new StatementProxy<Statement>(this, statement);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency)
            throws SQLException {
        PreparedStatement preparedStatement = targetConnection.prepareStatement(sql, resultSetType,
                resultSetConcurrency);
        return new PreparedStatementProxy(this, preparedStatement, sql);
    }

    @Override
    public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency) throws SQLException {
        return targetConnection.prepareCall(sql, resultSetType, resultSetConcurrency);
    }

    @Override
    public Map<String, Class<?>> getTypeMap() throws SQLException {
        return targetConnection.getTypeMap();
    }

    @Override
    public void setTypeMap(Map<String, Class<?>> map) throws SQLException {
        targetConnection.setTypeMap(map);

    }

    @Override
    public void setHoldability(int holdability) throws SQLException {
        targetConnection.setHoldability(holdability);

    }

    @Override
    public int getHoldability() throws SQLException {
        return targetConnection.getHoldability();
    }

    @Override
    public Statement createStatement(int resultSetType, int resultSetConcurrency, int resultSetHoldability)
            throws SQLException {
        Statement statement = targetConnection.createStatement(resultSetType, resultSetConcurrency,
                resultSetHoldability);
        return new StatementProxy<Statement>(this, statement);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency,
                                              int resultSetHoldability) throws SQLException {
        PreparedStatement preparedStatement = targetConnection.prepareStatement(sql, resultSetType,
                resultSetConcurrency, resultSetHoldability);
        return new PreparedStatementProxy(this, preparedStatement, sql);
    }

    @Override
    public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency,
                                         int resultSetHoldability) throws SQLException {
        return targetConnection.prepareCall(sql, resultSetType, resultSetConcurrency, resultSetHoldability);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int autoGeneratedKeys) throws SQLException {
        PreparedStatement preparedStatement = targetConnection.prepareStatement(sql, autoGeneratedKeys);
        return new PreparedStatementProxy(this, preparedStatement, sql);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int[] columnIndexes) throws SQLException {
        PreparedStatement preparedStatement = targetConnection.prepareStatement(sql, columnIndexes);
        return new PreparedStatementProxy(this, preparedStatement, sql);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, String[] columnNames) throws SQLException {
        PreparedStatement preparedStatement = targetConnection.prepareStatement(sql, columnNames);
        return new PreparedStatementProxy(this, preparedStatement, sql);
    }

    @Override
    public Clob createClob() throws SQLException {
        return targetConnection.createClob();
    }

    @Override
    public Blob createBlob() throws SQLException {
        return targetConnection.createBlob();
    }

    @Override
    public NClob createNClob() throws SQLException {
        return targetConnection.createNClob();
    }

    @Override
    public SQLXML createSQLXML() throws SQLException {
        return targetConnection.createSQLXML();
    }

    @Override
    public boolean isValid(int timeout) throws SQLException {
        return targetConnection.isValid(timeout);
    }

    @Override
    public void setClientInfo(String name, String value) throws SQLClientInfoException {
        targetConnection.setClientInfo(name, value);

    }

    @Override
    public void setClientInfo(Properties properties) throws SQLClientInfoException {
        targetConnection.setClientInfo(properties);

    }

    @Override
    public String getClientInfo(String name) throws SQLException {
        return targetConnection.getClientInfo(name);
    }

    @Override
    public Properties getClientInfo() throws SQLException {
        return targetConnection.getClientInfo();
    }

    @Override
    public Array createArrayOf(String typeName, Object[] elements) throws SQLException {
        return targetConnection.createArrayOf(typeName, elements);
    }

    @Override
    public Struct createStruct(String typeName, Object[] attributes) throws SQLException {
        return targetConnection.createStruct(typeName, attributes);
    }

    @Override
    public void setSchema(String schema) throws SQLException {
        targetConnection.setSchema(schema);

    }

    @Override
    public String getSchema() throws SQLException {
        return targetConnection.getSchema();
    }

    @Override
    public void abort(Executor executor) throws SQLException {
        targetConnection.abort(executor);

    }

    @Override
    public void setNetworkTimeout(Executor executor, int milliseconds) throws SQLException {
        targetConnection.setNetworkTimeout(executor, milliseconds);
    }

    @Override
    public int getNetworkTimeout() throws SQLException {
        return targetConnection.getNetworkTimeout();
    }

    @Override
    public <T> T unwrap(Class<T> iface) throws SQLException {
        return targetConnection.unwrap(iface);
    }

    @Override
    public boolean isWrapperFor(Class<?> iface) throws SQLException {
        return targetConnection.isWrapperFor(iface);
    }
}

